Skip to content

[ExecuTorch][WebGPU] Dynamic resize hook for SDPA (live seq-len S)#20580

Open
JulianCloudNTH wants to merge 2 commits into
gh/JulianCloudNTH/72/basefrom
gh/JulianCloudNTH/72/head
Open

[ExecuTorch][WebGPU] Dynamic resize hook for SDPA (live seq-len S)#20580
JulianCloudNTH wants to merge 2 commits into
gh/JulianCloudNTH/72/basefrom
gh/JulianCloudNTH/72/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Make sdpa_with_kv_cache serve any live seq-len S from one graph (batched prefill S=K and decode S=1).

Problem: the existing dynamic path only reacted to a live input_pos (decode), with S captured at build time. It rewrote the QK dispatch (which depends on context_len) but left update_cache, softmax, and AV sized for the build-time S. Under a dynamic seq-len S (one graph serving prefill and decode), kv_numel, the QK/AV tile grids, and the softmax row count all depend on S and were stale.

Solution: a single recompute hook driven by either a live S (q tensor resize) or a live input_pos (SymInt), recomputing every per-step quantity from the live shape.

  • Before: hook keyed only on input_pos; recomputes ctx + QK count; S fixed.
  • After: hook keyed on q (always) and input_pos (when SymInt); reads live S from cur_dims(q) and live pos, recomputes all five dispatches' counts + UBOs (update_cache K/V, QK, softmax, AV), and sets the output cur_dims to q's.

Implementation:

  • Capture the update_cache/softmax/AV dispatch indices (previously only QK) so their workgroup counts can be rewritten per step.
  • QK/AV workgroup counts use the landed register-tiled grids (Hq*ceil(S/TM)*ceil(ctx-or-D/TN)); softmax is one workgroup per Hq*S row.
  • Register the hook on q unconditionally — inert until q is resized, so a static graph is byte-identical.
  • Mirrors Vulkan DynamicDispatchNode (recompute workgroups per execute); scratch is sized at build (S=max, ctx=Cmax) so buffers never move and bind groups stay valid.

Constraints: fp32-only, batch=1, GQA, is_causal=true, D%4==0 invariants unchanged; the static / decode-only paths are unaffected (the q hook never fires without a resize).

Co-authored-with: Claude Code.

Differential Revision: D109906097

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 28, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20580

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Pending

As of commit 6ee7775 with merge base 0cef6de (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 28, 2026
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review and check for any areas or opportunities for modularization

@claude

claude Bot commented Jun 29, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 2m 39s —— View job


Review: Dynamic resize hook for SDPA (live seq-len S)

  • Read the changed file and PR diff
  • Analyze correctness of the dynamic resize hook
  • Check for modularization opportunities
  • Post review feedback

The change is well-conceived: keying the recompute on the q tensor (always) and input_pos (only when SymInt), reusing the make_*_params builders, and sizing scratch at the build-time max so bind groups stay valid. The static / decode-only paths really are left untouched. A few notes below — the headline is the modularization opportunity you asked about, plus one correctness guard worth adding.


Modularization (the main ask)

1. The hook re-implements the five-dispatch build, line for line. The biggest win here. Dispatches 1–5 (Sdpa.cpp:449-561) and the hook body (Sdpa.cpp:590-650) compute the same quantities — kv_numel, kv_off, the QK/AV tile grids (Hq * div_up(s, TM) * div_up(N, TN)), the softmax row count, and all five UBO structs — from (s, pos). The build path computes them once for the build-time (S, input_pos); the hook recomputes them for the live (s, pos). That's two copies of the same arithmetic that must be kept in sync by hand (note the QK grid formula already appears three times in the file).

Extracting a single struct + function that maps (s, pos) → counts and params would collapse both sites:

struct SdpaStepCounts {
  UpdateCacheParams uc;
  AttnWeightsParams qk;
  SoftmaxParams softmax;
  ComputeOutParams av;
  uint32_t uc_wgc, qk_wgc, softmax_wgc, av_wgc;
};

// Computes every per-step quantity from the live (s, pos); shared by the
// initial build and the resize hook so the two can never drift.
static SdpaStepCounts compute_sdpa_step(
    WGPUDevice device, int64_t s, int64_t pos, const SdpaDims& d,
    uint32_t uc_wg, uint32_t qk_wg, uint32_t av_wg);

The build path calls it once to populate the initial UBOs/workgroup_count_x; the hook calls it and just writes the buffers + dispatch_at(...).workgroup_count_x. This removes ~50 lines of duplicated logic and makes the "recompute all five" invariant structural rather than maintained by hand. Fix this →

2. The 24-entry lambda capture list (Sdpa.cpp:566-590). Once #1 lands, group the captured state into two small structs — a SdpaDims (Hq, Hkv, D, Cmax, g, scale, qn) and the dispatch index/buffer handles — and capture those by value. The lambda drops to ~4 captures and the dependency surface becomes legible.

3. A tile-grid helper. Hq * div_up(s, kSdpaTileM) * div_up(N, kSdpaTileN) appears for QK (build + hook) and AV (build + hook). A tile_workgroup_count(device, Hq, s, N, wg, label) folds those four sites into one and documents the register-tiled grid in a single place.


Correctness

Scratch-capacity guard is implicit. The scratch buffers are sized at build for Hq * S_build * Cmax floats (Sdpa.cpp:428-434). The hook validates ctx <= Cmax and aw_floats <= UINT32_MAX (Sdpa.cpp:599,608) but never checks that the live s doesn't exceed the build-time S. Out-of-bounds into attn_weights is avoided only because resize_input clamps q's live dims to its build-time max, so s <= S_build holds transitively. That's a real invariant the hook depends on but doesn't state. Since the existing aw_floats > UINT32_MAX check is already a defensive guard, consider making the binding explicit — e.g. validate Hq * s * ctx <= aw_cap_floats (capture aw_cap_floats) — so a future caller that resizes scratch independently can't silently overflow.

Redundant double-fire when both triggers are dirty. With dynamic_pos, the same closure is registered as both a tensor hook on q and a SymInt hook on input_pos (Sdpa.cpp:655-657). In propagate_resize the SymInt hooks run, then the tensor hooks (WebGPUGraph.cpp:191-209), so a step that dirties both input_pos and q's dims runs the hook twice (redundant wgpuQueueWriteBuffer calls). It's idempotent so not a bug, and the common decode (pos changes, s=1 stays) / prefill (q resizes) cases only trip one path — just worth a one-line comment noting the intentional idempotency, since registering one function on two triggers invites the question.


Minor

  • Sdpa.cpp:591 gr.cur_dims(q_id)[qn - 3] relies on the live rank matching the captured build-time qn. Safe given set_cur_dims preserves rank, but an assert(gr.cur_dims(q_id).size() == qn) would document the assumption cheaply.
  • The error string at Sdpa.cpp:609 ("exceeds uint32") differs slightly from the build-path version at :484 ("exceeds uint32 max"). Trivial, but a shared helper (Add support for quantized LeakyReLU #1) would unify these too.

Overall this is a clean extension of the dynamic path and the Vulkan DynamicDispatchNode parallel is apt. The duplication between build and hook is the one thing I'd address before landing — it's exactly the kind of two-copies-must-agree arithmetic that drifts.
· gh/JulianCloudNTH/72/head

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants